import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os, sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from base import env1, Memory
import time
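
# DDPG (Deep Deterministic Policy Gradient): an actor outputs a deterministic
# continuous action, a critic estimates Q(s, a), and slowly tracking target
# copies of both stabilise the bootstrapped targets. env1 comes from the local
# base module; the tanh*2 output scaling suggests a Pendulum-style task with
# a 1-D action in [-2, 2].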

class Actor(nn.Module):
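    """Deterministic policy network: state -> continuous action."""
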
    def __init__(self,n_state,n_action):
        super(Actor, self).__init__()
        self.seq = nn.Sequential(
            nn.Linear(n_state,30),
            nn.ReLU(),
            nn.Linear(30,n_action)
        )
        # Walk the sub-modules (not the parameters) so the isinstance
        # check can match the Linear layers and initialise their weights.
        for m in self.seq.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.1)

    def forward(self, state):
        x = self.seq(state)
        # tanh squashes to (-1, 1); the *2 rescales to the environment's
        # action range of [-2, 2].
        return torch.tanh(x) * 2.

class Critic(nn.Module):
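    """Q-network: maps a (state, action) pair to a scalar action value."""
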
    def __init__(self, n_state, n_action):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(n_state, 30)
        self.fc1.weight.data.normal_(0, 0.1)
        self.fc2 = nn.Linear(n_action, 30)
        self.fc2.weight.data.normal_(0, 0.1)
        self.fc3 = nn.Linear(60, 60)
        self.fc3.weight.data.normal_(0, 0.1)
        # The Q-value is a scalar per (state, action) pair.
        self.out = nn.Linear(60, 1)
        self.out.weight.data.normal_(0, 0.1)

    def forward(self, state, action):
        # Embed state and action separately, then merge the two branches.
        x1 = self.fc1(state)
        x2 = self.fc2(action)
        x = F.relu(torch.cat([x1, x2], dim=1))
        x = F.relu(self.fc3(x))
        return self.out(x)

class DDPG(Memory):
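    """DDPG agent built on base.Memory, which (judging by its use here) is
    assumed to supply the replay buffer (put_transition / get_transition),
    self.device, self.step, and a TensorBoard writer at self.writer."""
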
    def __init__(self,n_state,n_action,
                 explore_var=3.,explore_var_decay=0.9995,
                 gamma=0.9,TAU=0.01,actor_lr=1e-3,critic_lr=0.002,
                 capacity=10000,logs='./logs'):
        super(DDPG,self).__init__(capacity,logs)
        self.actor = Actor(n_state, n_action).to(self.device)
        self.actor_target = Actor(n_state, n_action).to(self.device)
        # Start each target network from the same weights as its online copy.
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)

        self.critic = Critic(n_state, n_action).to(self.device)
        self.critic_target = Critic(n_state, n_action).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.explore_var = explore_var
        self.explore_var_decay = explore_var_decay
        self.gamma = gamma
        self.TAU = TAU

    def choose_action(self, state):
        # Anneal the exploration noise every time an action is drawn.
        self.explore_var *= self.explore_var_decay
        state = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(self.device)
        with torch.no_grad():
            action = self.actor(state).item()  # .item() assumes a 1-D action space
        # Gaussian exploration around the deterministic action, clipped
        # to the environment's valid action range.
        action = np.random.normal(action, self.explore_var)
        return np.clip(action, env1.action_space.low, env1.action_space.high)

    def optimize(self,batch):
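        """One update step on a sampled minibatch: regress the critic toward
        the Bellman target, ascend the critic's Q with the actor, then
        soft-update both target networks."""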
        s = torch.tensor(batch.s,dtype=torch.float).to(self.device)
        a = torch.tensor(batch.a,dtype=torch.float).to(self.device)
        r = torch.tensor(batch.r,dtype=torch.float).to(self.device)
        s_ = torch.tensor(batch.s_, dtype=torch.float).to(self.device)
        done = torch.tensor(batch.done, dtype=torch.float).to(self.device)

        q_predicted = self.critic(s, a)
        # Bellman target r + gamma * Q'(s', mu'(s')), masked at terminal states.
        # no_grad keeps gradients out of the target networks.
        with torch.no_grad():
            q_target = self.critic_target(s_, self.actor_target(s_))
            q_target = r + self.gamma * (1 - done) * q_target

        # F.mse_loss already averages over the batch.
        critic_loss = F.mse_loss(q_predicted, q_target)
        self.writer.add_scalar('loss_critic',critic_loss.item(),self.step)

        self.critic_opt.zero_grad()
        critic_loss.backward()
        self.critic_opt.step()

        # Compute each loss immediately before its backward pass, otherwise
        # the gradients would be overwritten.
        # Actor loss: ascend the critic's Q estimate (hence the minus sign).
        actor_loss = -1. * self.critic(s, self.actor(s)).mean()
        self.writer.add_scalar('loss_actor',actor_loss.item(),self.step)

        self.actor_opt.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

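        # Soft (Polyak) update: theta_target <- TAU*theta + (1-TAU)*theta_target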
        for p,tp in zip(self.critic.parameters(),self.critic_target.parameters()):
            tp.data.copy_(self.TAU*p.data + (1-self.TAU)*tp.data)

        for p,tp in zip(self.actor.parameters(),self.actor_target.parameters()):
            tp.data.copy_(self.TAU*p.data + (1-self.TAU)*tp.data)

if __name__=='__main__':
    epoch = 200
    max_step_per_epoch = 200
    batch_size = 64
    env = env1
    env.seed(int(time.time()))

    model = DDPG(env.observation_space.shape[0],env.action_space.shape[0])

    print("Sampling...")
    count = 0
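    # Warm-up: fill the replay buffer to capacity before any gradient step.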
    while count < model.capacity:
        state = env.reset()
        for _ in range(max_step_per_epoch):
            # Sample uniform random actions here so the exploration-noise
            # schedule inside choose_action is not consumed before training.
            action = env.action_space.sample()
            state_, reward, done, info = env.step(action)
            model.put_transition(state, action, [reward], state_, [done])
            state = state_
            count += 1
            if done:
                break

    count = 0
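    # Training: act, store the transition, then optimize on a sampled
    # minibatch at every environment step.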
    for ep in range(epoch):
        epoch_r = 0
        state = env.reset()
        for st in range(max_step_per_epoch):
            action = model.choose_action(state)
            state_,reward,done,info = env.step(action)
            model.put_transition(state,action,[reward],state_,[done])
            state = state_

            batch = model.get_transition(batch_size=batch_size)
            model.optimize(batch)
            env.render()

            count += 1
            epoch_r += reward
            print('Epoch:[{}/{}],step:[{}/{}],ep_r:{:.4f}'.format(
                ep + 1, epoch, st + 1, max_step_per_epoch, epoch_r))
            if done:
                break
        model.writer.add_scalar('ep_r',epoch_r,count)
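
    # A minimal greedy-evaluation sketch (an addition, not part of the
    # original training loop): driving explore_var to ~0 makes choose_action
    # effectively deterministic, so the rollout return approximates the
    # learned policy's performance.
    model.explore_var = 1e-8
    state = env.reset()
    eval_r = 0.
    for _ in range(max_step_per_epoch):
        action = model.choose_action(state)
        state, reward, done, info = env.step(action)
        eval_r += reward
        if done:
            break
    print('Greedy evaluation return: {:.2f}'.format(eval_r))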